[ROCm] Restore 16-wide fast path in Triton unified attention#30582

Open
hyoon1 wants to merge 1 commit into vllm-project:main from hyoon1:fix-unified-attn-rdna-2

Conversation

@hyoon1
Contributor

@hyoon1 hyoon1 commented Dec 13, 2025

Purpose

  • PR [Kernel] Enable Hybrid Model Support in Triton Unified Attention Kernel #21197 decoupled the compute tile from the paged KV-cache block so that the Triton unified-attention kernel could feed wider tiles to hybrid-model configurations. That forced every loop iteration to compute tile_mask, seq_offset // BLOCK_SIZE, and related indexing, and on AMD Radeon RDNA 3/4 (gfx11/gfx12) those extra ops caused a significant regression because the hardware WMMA fragment is still 16 columns wide.
  • Add a block-aligned fast path to both the 2D and 3D Triton unified-attention kernels: when TILE_SIZE == BLOCK_SIZE, the kernel reverts to the original contiguous indexing (no div/mod/masking); otherwise it runs the decoupled loop.
  • On AMD Radeon RDNA 3/4 (gfx11/gfx12) with ≥16-bit queries, set TILE_SIZE_PREFILL = TILE_SIZE_DECODE = block_size so that both prefill and decode align with the GPU's 16-wide WMMA fragment and consistently hit the fast path.
  • Keep the widened default tiles (32 columns) on other architectures, preserving hybrid-model flexibility while eliminating the Navi-specific regression.
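The indexing change behind the fast path can be sketched in plain Python (this is an illustration of the two schemes, not the actual Triton kernel; the function and variable names here are hypothetical):

```python
def kv_tile_indices(tile_start, TILE_SIZE, BLOCK_SIZE, block_table):
    """Illustrative sketch: physical KV-cache slots for one compute tile.

    `block_table` maps logical block number -> physical block number.
    """
    if TILE_SIZE == BLOCK_SIZE:
        # Fast path: the tile covers exactly one cache block, so a single
        # table lookup plus contiguous offsets suffices -- no per-element
        # div/mod and no tile masking.
        block = block_table[tile_start // BLOCK_SIZE]
        return [block * BLOCK_SIZE + i for i in range(BLOCK_SIZE)]
    # Decoupled (general) path: each position needs its own div/mod to
    # locate its cache block and the offset within that block.
    slots = []
    for i in range(TILE_SIZE):
        seq_offset = tile_start + i
        block = block_table[seq_offset // BLOCK_SIZE]
        slots.append(block * BLOCK_SIZE + seq_offset % BLOCK_SIZE)
    return slots
```

On RDNA 3/4 the per-element div/mod in the general path is pure overhead whenever tiles are block-aligned anyway, which is why this PR pins the tile sizes to block_size there.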

Test Plan

Serve meta-llama/Meta-Llama-3.1-8B-Instruct and compare benchmark results before and after the change.

vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct --max-model-len 4096

vllm bench serve --model meta-llama/Meta-Llama-3.1-8B-Instruct --trust-remote-code --dataset-name sharegpt --dataset-path sharegpt/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 1000

Run lm_eval for a correctness check:

lm_eval --model vllm --model_args pretrained=meta-llama/Meta-Llama-3.1-8B-Instruct --tasks gsm8k --num_fewshot 5 --batch_size auto --limit 500

Test Result

Performance & Correctness

AMD Radeon Pro W7900 (RDNA3)
Original Triton unified attention kernel:

============ Serving Benchmark Result ============
Successful requests:                     1000
Failed requests:                         0
Benchmark duration (s):                  209.30
Total input tokens:                      215196
Total generated tokens:                  197146
Request throughput (req/s):              4.78
Output token throughput (tok/s):         941.95
Peak output token throughput (tok/s):    1627.00
Peak concurrent requests:                1000.00
Total token throughput (tok/s):          1970.14
---------------Time to First Token----------------
Mean TTFT (ms):                          63017.69
Median TTFT (ms):                        55648.32
P99 TTFT (ms):                           153252.16
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          250.99
Median TPOT (ms):                        237.28
P99 TPOT (ms):                           538.50
---------------Inter-token Latency----------------
Mean ITL (ms):                           217.99
Median ITL (ms):                         187.04
P99 ITL (ms):                            561.98
==================================================
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.774|±  |0.0187|
|     |       |strict-match    |     5|exact_match|↑  |0.722|±  |0.0201|

Updated version:

============ Serving Benchmark Result ============
Successful requests:                     1000
Failed requests:                         0
Benchmark duration (s):                  181.08
Total input tokens:                      215196
Total generated tokens:                  197815
Request throughput (req/s):              5.52
Output token throughput (tok/s):         1092.39
Peak output token throughput (tok/s):    1920.00
Peak concurrent requests:                1000.00
Total token throughput (tok/s):          2280.76
---------------Time to First Token----------------
Mean TTFT (ms):                          54988.04
Median TTFT (ms):                        48941.78
P99 TTFT (ms):                           132739.99
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          215.39
Median TPOT (ms):                        202.28
P99 TPOT (ms):                           494.32
---------------Inter-token Latency----------------
Mean ITL (ms):                           185.68
Median ITL (ms):                         153.16
P99 ITL (ms):                            512.45
==================================================
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.788|±  |0.0183|
|     |       |strict-match    |     5|exact_match|↑  |0.730|±  |0.0199|

AMD MI308X
Original Triton unified attention kernel:

============ Serving Benchmark Result ============
Successful requests:                     1000
Failed requests:                         0
Benchmark duration (s):                  58.89
Total input tokens:                      215196
Total generated tokens:                  197124
Request throughput (req/s):              16.98
Output token throughput (tok/s):         3347.18
Peak output token throughput (tok/s):    7271.00
Peak concurrent requests:                1000.00
Total token throughput (tok/s):          7001.23
---------------Time to First Token----------------
Mean TTFT (ms):                          10650.23
Median TTFT (ms):                        10268.59
P99 TTFT (ms):                           21028.96
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          272.65
Median TPOT (ms):                        149.06
P99 TPOT (ms):                           738.67
---------------Inter-token Latency----------------
Mean ITL (ms):                           117.73
Median ITL (ms):                         80.86
P99 ITL (ms):                            743.77
==================================================
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.778|±  |0.0186|
|     |       |strict-match    |     5|exact_match|↑  |0.720|±  |0.0201|

Updated version:

============ Serving Benchmark Result ============
Successful requests:                     1000
Failed requests:                         0
Benchmark duration (s):                  58.82
Total input tokens:                      215196
Total generated tokens:                  197037
Request throughput (req/s):              17.00
Output token throughput (tok/s):         3350.08
Peak output token throughput (tok/s):    6854.00
Peak concurrent requests:                1000.00
Total token throughput (tok/s):          7008.91
---------------Time to First Token----------------
Mean TTFT (ms):                          10385.18
Median TTFT (ms):                        9877.28
P99 TTFT (ms):                           20799.42
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          275.07
Median TPOT (ms):                        150.40
P99 TPOT (ms):                           738.53
---------------Inter-token Latency----------------
Mean ITL (ms):                           118.42
Median ITL (ms):                         82.08
P99 ITL (ms):                            747.28
==================================================
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.782|±  |0.0185|
|     |       |strict-match    |     5|exact_match|↑  |0.726|±  |0.0200|
  • AMD Radeon Pro W7900 (RDNA3): Observed a significant boost in performance, with output token throughput increasing by approximately 16%.
  • AMD MI308X (CDNA, for comparison): Throughput remains comparable to the baseline, with no significant difference.
  • Accuracy: Both environments show a slight improvement in accuracy.


@hyoon1 hyoon1 requested a review from tdoublep as a code owner December 13, 2025 01:02
@hyoon1 hyoon1 changed the title restore 16-wide fast path in Triton unified attention [ROCm] restore 16-wide fast path in Triton unified attention Dec 13, 2025
@hyoon1 hyoon1 changed the title [ROCm] restore 16-wide fast path in Triton unified attention [ROCm] Restore 16-wide fast path in Triton unified attention Dec 13, 2025
@mergify mergify bot added the rocm Related to AMD ROCm label Dec 13, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a performance optimization for AMD Navi GPUs by adding a fast path to Triton attention kernels. This is achieved by creating a special case for when TILE_SIZE equals BLOCK_SIZE, which avoids expensive division and modulo operations. The changes are effective, as shown by the performance benchmarks. However, the implementation introduces significant code duplication in both kernel_unified_attention_2d and kernel_unified_attention_3d kernels. This duplication makes the code harder to maintain and reason about. My review focuses on refactoring these kernels to eliminate the code duplication while preserving the performance benefits.


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines +255 to +258
# V : (BLOCK_SIZE, HEAD_SIZE)
V_load = tl.load(
    value_cache_ptr + v_offset, mask=dim_mask[None, :], other=0.0
)

P1: Mask padded V tiles in block-aligned fast path

When TILE_SIZE == BLOCK_SIZE, the new fast path loads the entire cache block of V with only a head-dimension mask (tl.load(..., mask=dim_mask[None, :])). For sequences whose final KV block is only partially filled, the padding entries remain whatever was previously in that page; if any of those bytes happen to be NaN, the subsequent tl.dot(P, V) will multiply zeros by NaNs and propagate NaNs into acc even though the positions are softmax-masked. The general path used tile_mask to zero out those columns, avoiding reads of uninitialized padding. The fast path needs the same masking (likewise in the new 3D fast path) to prevent corrupted outputs on partially filled blocks.
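The hazard described above can be reproduced in plain Python with a toy 1-D stand-in for tl.dot(P, V) (names and shapes here are illustrative, not the kernel's):

```python
import math

def masked_attention_row(p, v):
    """Toy analogue of one row of tl.dot(P, V).

    Softmax-masked positions have p == 0.0, but 0.0 * nan is nan under
    IEEE 754, so uninitialized padding in v still corrupts the sum.
    """
    return sum(pi * vi for pi, vi in zip(p, v))

p = [0.5, 0.5, 0.0, 0.0]                   # last two columns softmax-masked
v_clean = [1.0, 3.0, 0.0, 0.0]             # padding zeroed (as tile_mask does)
v_dirty = [1.0, 3.0, math.nan, math.nan]   # stale contents of a partial page

print(masked_attention_row(p, v_clean))    # 2.0
print(masked_attention_row(p, v_dirty))    # nan, despite the zero weights
```

This is why masking only the head dimension is insufficient: zero softmax weights do not neutralize NaN values read from unwritten cache slots.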


@hyoon1 hyoon1 force-pushed the fix-unified-attn-rdna-2 branch 2 times, most recently from 9a30335 to f18b504 Compare December 13, 2025 02:08
@ApostaC
Collaborator

ApostaC commented Dec 16, 2025

Hey @tdoublep , can you help review this PR? Thanks!

Signed-off-by: Hosang Yoon <hosang.yoon@amd.com>
@hyoon1 hyoon1 force-pushed the fix-unified-attn-rdna-2 branch from f18b504 to 5837be3 Compare February 5, 2026 22:23
@mergify mergify bot added the v1 label Feb 5, 2026
@github-project-automation github-project-automation bot moved this to Todo in AMD Feb 5, 2026
@hyoon1
Contributor Author

hyoon1 commented Feb 5, 2026

Hi @tdoublep, could you review this? I rebased onto the latest main. cc @gshtras @tjtanaa

